import json
from typing import Any, Dict, List, Union

import qianfan
from qianfan import ChatCompletion, Completion, QfResponse
from qianfan.common import Prompt
from qianfan.evaluation.consts import (
    LocalJudgeEvaluatorPromptTemplate,
    QianfanRefereeEvaluatorDefaultMaxScore,
    QianfanRefereeEvaluatorDefaultMetrics,
    QianfanRefereeEvaluatorDefaultSteps,
)
from qianfan.evaluation.evaluator import LocalEvaluator
from qianfan.utils.pydantic import Field


class LocalJudgeEvaluator(LocalEvaluator):
    """Local LLM-as-judge evaluator.

    Renders an evaluation prompt (criteria + steps + max score + the original
    question, the reference answer and the model output), sends it to a judge
    model (``ChatCompletion`` or ``Completion``) and returns the judge's raw
    verdict text keyed by ``metric_name``.
    """

    # Judge model used to score the output. NOTE(review): default is None,
    # so `evaluate` will fail with "Unsupported model type" / the
    # ChatCompletion check unless a model is supplied.
    model: Union[ChatCompletion, Completion] = Field(
        default=None, description="model object"
    )

    # Key under which the verdict is returned from `evaluate`.
    metric_name: str = Field(default="", description="metric name for evaluation")

    # Extra keyword arguments forwarded to `model.do(...)` (pydantic copies
    # field defaults per-instance, so the `{}` default is not shared).
    model_kwargs: Dict[str, Any] = Field(default={}, description="parameters for model")
    evaluation_prompt: Prompt = Field(
        default=Prompt(LocalJudgeEvaluatorPromptTemplate),
        description="concrete evaluation prompt string",
    )
    criteria: str = Field(
        default=QianfanRefereeEvaluatorDefaultMetrics, description="evaluation metrics"
    )
    prompt_steps: str = Field(
        default=QianfanRefereeEvaluatorDefaultSteps, description="evaluation steps"
    )
    prompt_max_score: int = Field(
        default=QianfanRefereeEvaluatorDefaultMaxScore,
        description="max score for evaluation",
    )

    class Config:
        # Prompt / ChatCompletion / Completion are not pydantic models.
        arbitrary_types_allowed = True

    def _render_judge_prompt(self, question: str, reference: str, output: str) -> str:
        """Render the evaluation prompt template into a single prompt string."""
        prompt_text, _ = self.evaluation_prompt.render(
            criteria=self.criteria,
            steps=self.prompt_steps,
            max_score=str(self.prompt_max_score),
            prompt=question,
            reference=reference,
            # Bug fix: this was previously `response=reference`, so the judge
            # model never saw the output under evaluation.
            response=output,
        )
        return prompt_text

    def _chat(self, prompt_text: str) -> Any:
        """Send the rendered prompt to a ChatCompletion judge model."""
        msg = qianfan.Messages()
        msg.append(prompt_text)
        return self.model.do(
            messages=msg,
            **self.model_kwargs,
        )

    def evaluate(
        self, input: Union[str, List[Dict[str, Any]]], reference: str, output: str
    ) -> Dict[str, Any]:
        """
        use model to evaluate in local
        Args:
            input (Union[str, List[Dict[str, Any]]]):
                given prompts.
                when is_chat in evaluateManager.eval() is true,
                input will be a chat history otherwise a prompt
            reference (str):
                reference answers, given by user
            output (str):
                output answers from llm, generated by llm

        Returns:
            Dict[str, Any]: evaluate result in json schema

        Raises:
            ValueError: when the input shape or the model type is unsupported
        """
        if isinstance(input, list):
            # Chat history input requires a chat-capable judge model.
            if not isinstance(self.model, ChatCompletion):
                raise ValueError("model is not an instance of ChatCompletion")
            # A valid chat history is either a single user message or an even
            # number of alternating user/assistant turns.
            if len(input) != 1 and len(input) % 2 != 0:
                raise ValueError("chat history is not single text or dialogs")
            # Single-turn: use the message text; multi-turn: embed the whole
            # history as JSON into the evaluation prompt.
            question = (
                input[0].get("content", "") if len(input) == 1 else json.dumps(input)
            )
            resp = self._chat(self._render_judge_prompt(question, reference, output))
        elif isinstance(input, str):
            prompt_text = self._render_judge_prompt(input, reference, output)
            # Keep the Completion check first to preserve dispatch order for
            # models that might satisfy both isinstance checks.
            if isinstance(self.model, Completion):
                resp = self.model.do(
                    prompt=prompt_text,
                    **self.model_kwargs,
                )
            elif isinstance(self.model, ChatCompletion):
                resp = self._chat(prompt_text)
            else:
                raise ValueError("Unsupported model type")
        else:
            raise ValueError(f"input in {type(input)} not supported")

        # Non-streaming calls return a single QfResponse.
        assert isinstance(resp, QfResponse)
        result = resp["result"].strip()
        return {self.metric_name: result}
